'''
@Time    : 2022/3/21 15:33
@Author  : Fu Junyu
@Site    : www.fujunyu.cn
@File    : RGCN-Train.py
@Software: PyCharm
'''


# load graph data
import torch
from dgl._deprecate.graph import DGLGraph
from dgl.contrib.data import load_data
import torch.nn.functional as F
import numpy as np

from test_Rgcn_DGL import Model

# Load the 'aifb' dataset and pull out the sizes used to build the model.
data = load_data(dataset='aifb')
num_nodes, num_rels, num_classes = (
    data.num_nodes,
    data.num_rels,
    data.num_classes,
)

# Flatten the label array into a 1-D tensor.
labels = torch.from_numpy(data.labels).view(-1)

# Hold out the first fifth of the provided training indices for validation;
# the remaining four fifths stay in the training set.
_val_size = len(data.train_idx) // 5
val_idx = data.train_idx[:_val_size]
train_idx = data.train_idx[_val_size:]

# Per-edge relation type, converted to a tensor.
edge_type = torch.from_numpy(data.edge_type)
# Per-edge normalization factor; unsqueeze(1) appends a trailing dimension.
edge_norm = torch.from_numpy(data.edge_norm).unsqueeze(1)

# ---- configurations -------------------------------------------------------
# Model shape
# number of hidden units
n_hidden = 16
# -1: use number of relations as number of bases
n_bases = -1
# 0: use 1 input layer, 1 output layer, no hidden layer
n_hidden_layers = 0

# Optimisation
# epochs to train
n_epochs = 25
# learning rate
lr = 0.01
# L2 norm coefficient
l2norm = 0

# create graph
# Build the graph from the dataset's source/destination edge arrays.
g = DGLGraph((data.edge_src, data.edge_dst))
# Attach the per-edge relation type and normalization factor. The mapping's
# update() returns None, so wrapping it in print() only ever printed "None";
# perform the update first, then print the edge data itself.
g.edata.update({'rel_type': edge_type, 'norm': edge_norm})
print(g.edata)


# create model
# Model comes from test_Rgcn_DGL; the positional arguments are passed in the
# order that constructor expects. len(g) is presumably the node count of the
# graph -- TODO confirm against Model's signature.
model = Model(
    len(g), n_hidden, num_classes, num_rels,
    num_bases=n_bases,
    num_hidden_layers=n_hidden_layers,
)

# optimizer
# Adam over every model parameter; l2norm is supplied as the weight decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    weight_decay=l2norm,
)

# Announce the training phase.
# NOTE(review): the training loop below is entirely commented out, so the
# script prints this banner but performs no optimisation steps. Either restore
# the loop or drop the banner once the intent is confirmed.
print("start training...")
# --- disabled training loop (kept for reference) ---------------------------
# NOTE(review): as written, the validation loss/accuracy are computed from the
# same `logits` produced by the training forward pass of that epoch (before
# optimizer.step() takes effect), not from a separate evaluation pass.
# model.train()
# for epoch in range(n_epochs):
#     optimizer.zero_grad()
#     logits = model.forward(g)
#     loss = F.cross_entropy(logits[train_idx], labels[train_idx])
#     loss.backward()
#
#     optimizer.step()
#
#     train_acc = torch.sum(logits[train_idx].argmax(dim=1) == labels[train_idx])
#     train_acc = train_acc.item() / len(train_idx)
#     val_loss = F.cross_entropy(logits[val_idx], labels[val_idx])
#     val_acc = torch.sum(logits[val_idx].argmax(dim=1) == labels[val_idx])
#     val_acc = val_acc.item() / len(val_idx)
#     print("Epoch {:05d} | ".format(epoch) +
#           "Train Accuracy: {:.4f} | Train Loss: {:.4f} | ".format(
#               train_acc, loss.item()) +
#           "Validation Accuracy: {:.4f} | Validation loss: {:.4f}".format(
#               val_acc, val_loss.item()))

# Print the graph's node data (`g.ndata`) as a final inspection step.
print(g.ndata)